import time
import logging
import os
import random
import torch
import torch.utils.data
from . import base

import pandas as pd 
import numpy as np
import csv, json

from tqdm import tqdm

class SdfLoader(base.Dataset):
    """Dataset that yields per-shape SDF samples, point clouds and segmentation.

    Each item is read from one file in ``self.gt_files``. That list is produced
    by ``get_instance_filenames``, which is not defined in this class —
    presumably inherited from ``base.Dataset`` (TODO confirm). Every file must
    contain the arrays ``xyz``, ``gt_sdf``, ``point_cloud``,
    ``full_point_cloud``, ``atc`` and ``full_seg``.
    """

    def __init__(
        self,
        data_source,
        split_file,
        grid_source=None,
        samples_per_mesh=16000,
        pc_size=1024,
        modulation_path=None,
    ):
        """Collect the list of sample files for this split.

        Args:
            data_source: Forwarded to ``get_instance_filenames``.
            split_file: Forwarded to ``get_instance_filenames``.
            grid_source: Accepted for backward compatibility; not used here.
            samples_per_mesh: Stored as ``self.samples_per_mesh``. Not used by
                ``__getitem__``; all stored samples are returned.
            pc_size: Stored as ``self.pc_size``. Not used by ``__getitem__``.
            modulation_path: Forwarded to ``get_instance_filenames`` as
                ``filter_modulation_path``.
        """
        self.samples_per_mesh = samples_per_mesh
        self.pc_size = pc_size
        self.gt_files = self.get_instance_filenames(
            data_source, split_file, filter_modulation_path=modulation_path
        )

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict.

        Returns:
            dict with keys:
              - ``'xyz'``, ``'gt_sdf'``, ``'point_cloud'``, ``'full_point_cloud'``:
                float32 tensors built from the stored arrays.
              - ``'atc'``: Python float. It is 0.0 when the stored value is
                complex-typed.
              - ``'seg'``: tensor from ``full_seg``, keeping its stored dtype.
        """
        data = np.load(self.gt_files[idx])
        try:
            xyz = torch.from_numpy(data['xyz']).float()
            gt_sdf = torch.from_numpy(data['gt_sdf']).float()
            point_cloud = torch.from_numpy(data['point_cloud']).float()
            full_point_cloud = torch.from_numpy(data['full_point_cloud']).float()

            atc = data['atc']
            # Complex-typed values cannot be converted with float(); fall back to 0.0.
            atc = 0.0 if np.iscomplexobj(atc) else float(atc)

            seg = torch.from_numpy(data['full_seg'])
        finally:
            # For .npz archives np.load returns an NpzFile that keeps the file
            # handle open until closed. Close it so long-running DataLoader
            # workers do not leak descriptors. Arrays already read from it are
            # in-memory copies and stay valid.
            close = getattr(data, 'close', None)
            if close is not None:
                close()

        return {
            'xyz': xyz,
            'gt_sdf': gt_sdf,
            'point_cloud': point_cloud,
            'atc': atc,
            'seg': seg,
            'full_point_cloud': full_point_cloud,
        }

    def __len__(self):
        """Return the number of sample files in this split."""
        return len(self.gt_files)